{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# Found cached data ./liberty.pt\n",
      "# Found cached data ./yosemite.pt\n",
      "# Found cached data ./notredame.pt\n",
      "# Found cached data ./yosemite.pt\n"
     ]
    }
   ],
   "source": [
    "import torchvision as tv\n",
    "import phototour\n",
    "import torch\n",
    "from tqdm import tqdm \n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import math \n",
    "\n",
    "# Keyword arguments shared by every triplet-mode training split.\n",
    "triplet_kwargs = dict(download=True, train=True, mode='triplets', augment=True)\n",
    "\n",
    "# Training splits; only liberty is given an explicit nsamples value.\n",
    "lib_train = phototour.PhotoTour('.', 'liberty', nsamples=409600, **triplet_kwargs)\n",
    "yos_train = phototour.PhotoTour('.', 'yosemite', **triplet_kwargs)\n",
    "nd_train = phototour.PhotoTour('.', 'notredame', **triplet_kwargs)\n",
    "\n",
    "# Held-out split (train=False) used for the per-epoch evaluation.\n",
    "eval_db = phototour.PhotoTour('.', 'yosemite', download=True, train=False)\n",
    "\n",
    "# Select the split to train on; train_name is reused when naming the saved weights.\n",
    "# train_db = torch.utils.data.ConcatDataset((lib_train, yos_train))\n",
    "train_db, train_name = nd_train, 'notredame'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tfeat_model\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "# Seed the RNGs *before* the network is constructed. Previously the seed was\n",
    "# set after TNet() had already been built, so the initial weights were not\n",
    "# covered by it (assuming TNet initialises from torch's global RNG).\n",
    "# NOTE(review): the datasets in the previous cell are created before this\n",
    "# point; if PhotoTour samples triplets at construction time, that sampling\n",
    "# is still unseeded -- confirm against phototour.py.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "# cv2.setRNGSeed(seed)\n",
    "\n",
    "# Build the descriptor network and move it to the GPU.\n",
    "tfeat = tfeat_model.TNet()\n",
    "tfeat = tfeat.cuda()\n",
    "\n",
    "# this kind of works\n",
    "# Plain SGD; StepLR multiplies the learning rate by 0.8 every 10 scheduler steps.\n",
    "optimizer = optim.SGD(tfeat.parameters(), lr=0.1, weight_decay=0.0001)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)\n",
    "\n",
    "# Training batches are served in dataset order (shuffle=False).\n",
    "# NOTE(review): presumably PhotoTour already randomises its triplets -- verify.\n",
    "train_loader = torch.utils.data.DataLoader(train_db,\n",
    "                                           batch_size=300, shuffle=False,\n",
    "                                           num_workers=30)\n",
    "\n",
    "# Evaluation pairs, also in dataset order.\n",
    "eval_loader = torch.utils.data.DataLoader(eval_db,\n",
    "                                          batch_size=1024, shuffle=False,\n",
    "                                          num_workers=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.1657\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:25,  8.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 0.12372\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 0.1015\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:26,  8.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 0.1019\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 0.10572\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:22,  8.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 0.09724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 0.10236\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7 0.08238\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8 0.08986\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9 0.09476\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 0.08732\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11 0.07934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12 0.08422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13 0.07902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14 0.0802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15 0.07846\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16 0.07706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:18,  8.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17 0.0855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18 0.07988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19 0.079\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:22,  8.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20 0.08532\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21 0.07216\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22 0.08076\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:22,  8.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23 0.07728\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24 0.07908\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:22,  8.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25 0.07762\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26 0.07134\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:34,  8.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27 0.07572\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28 0.07106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29 0.07926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30 0.07034\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:42,  8.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31 0.07322\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32 0.07258\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:32,  8.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33 0.07342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:24,  8.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34 0.07228\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "35 0.07366\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36 0.07084\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37 0.07102\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:22,  8.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38 0.0727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39 0.06934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "40 0.07024\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:18,  8.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "41 0.07024\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:24,  8.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42 0.0699\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:19,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43 0.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:23,  8.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "44 0.06696\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:21,  8.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45 0.07304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3334it [06:20,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46 0.0667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1919it [03:43,  8.60it/s]"
     ]
    }
   ],
   "source": [
    "def fpr95_from_distances(distances, labels):\n",
    "    '''Return the false-positive rate at 95% recall (FPR95).\n",
    "\n",
    "    Args:\n",
    "        distances: 1-D array-like of descriptor distances, one per pair.\n",
    "        labels: 1-D array-like of the same length; 1 marks a matching pair,\n",
    "            any other value a non-matching pair.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of non-matching pairs whose distance is <= the threshold\n",
    "        read at the 95% position of the sorted matching-pair distances.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no matching or no non-matching pairs.\n",
    "    '''\n",
    "    # float64 for both arrays, matching what the previous hstack-based code produced.\n",
    "    d = torch.from_numpy(np.asarray(distances, dtype=np.float64))\n",
    "    l = torch.from_numpy(np.asarray(labels, dtype=np.float64))\n",
    "    dist_pos = d[l == 1]\n",
    "    dist_neg = d[l != 1]\n",
    "    if dist_pos.numel() == 0 or dist_neg.numel() == 0:\n",
    "        raise ValueError('FPR95 needs at least one matching and one non-matching pair')\n",
    "    # FPR95 code from Yurun Tian\n",
    "    dist_pos, _ = torch.sort(dist_pos)\n",
    "    # ceil(0.95 * n) equals n for small n, which would index past the end;\n",
    "    # clamp to the last valid position.\n",
    "    loc_thr = min(int(np.ceil(dist_pos.numel() * 0.95)), dist_pos.numel() - 1)\n",
    "    thr = dist_pos[loc_thr]\n",
    "    return float(dist_neg.le(thr).sum()) / dist_neg.numel()\n",
    "\n",
    "\n",
    "fpr_per_epoch = []\n",
    "\n",
    "for e in range(300):\n",
    "    # ---- one pass over the training triplets ----\n",
    "    tfeat.train()\n",
    "    # Wrapping the loader itself (not enumerate(...)) lets tqdm show the total.\n",
    "    for data_a, data_p, data_n in tqdm(train_loader):\n",
    "        # Insert a dimension at position 1, cast to float and move to the GPU.\n",
    "        data_a = data_a.unsqueeze(1).float().cuda()\n",
    "        data_p = data_p.unsqueeze(1).float().cuda()\n",
    "        data_n = data_n.unsqueeze(1).float().cuda()\n",
    "        out_a, out_p, out_n = tfeat(data_a), tfeat(data_p), tfeat(data_n)\n",
    "        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin=2, swap=True)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # ---- evaluate the network after each epoch ----\n",
    "    tfeat.eval()\n",
    "    label_chunks, dist_chunks = [], []\n",
    "    # No gradients are needed here; no_grad avoids building autograd graphs.\n",
    "    with torch.no_grad():\n",
    "        for data_l, data_r, lbls in eval_loader:\n",
    "            data_l = data_l.unsqueeze(1).float().cuda()\n",
    "            data_r = data_r.unsqueeze(1).float().cuda()\n",
    "            out_l, out_r = tfeat(data_l), tfeat(data_r)\n",
    "            # L2 distance between the two descriptors of each pair.\n",
    "            dist_chunks.append(torch.norm(out_l - out_r, 2, 1).cpu().numpy())\n",
    "            label_chunks.append(lbls.numpy())\n",
    "\n",
    "    # Concatenate once instead of re-copying the growing arrays every batch.\n",
    "    fpr95 = fpr95_from_distances(np.concatenate(dist_chunks),\n",
    "                                 np.concatenate(label_chunks))\n",
    "    print(e, fpr95)\n",
    "    fpr_per_epoch.append([e, fpr95])\n",
    "    scheduler.step()\n",
    "    # Rewrite the full history each epoch so partial runs leave a usable log.\n",
    "    np.savetxt('fpr.txt', np.array(fpr_per_epoch), delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the network's state_dict to a file named after the training split.\n",
    "weights_path = train_name + '-tfeat.params'\n",
    "torch.save(tfeat.state_dict(), weights_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
